pytorch实战:详解查准率(Precision)、查全率(Recall)与F1

您所在的位置:网站首页 pytorch 分割模型代码 pytorch实战:详解查准率(Precision)、查全率(Recall)与F1

pytorch实战:详解查准率(Precision)、查全率(Recall)与F1

2023-10-02 09:39| 来源: 网络整理| 查看: 265

pytorch实战:详解查准率(Precision)、查全率(Recall)与F1 1、概述

本文首先介绍了机器学习分类问题的性能指标查准率(Precision)、查全率(Recall)与F1度量,阐述了多分类问题中的混淆矩阵及各项性能指标的计算方法,然后介绍了PyTorch中scatter函数的使用方法,借助该函数实现了对Precision、Recall、F1及正确率的计算,并对实现过程进行了解释。

观前提示:阅读本文需要你对机器学习与PyTorch框架具有一定的了解。

Tips:如果你只是想利用PyTorch计算查准率(Precision)、查全率(Recall)、F1这几个指标,不想深入了解,请直接跳到第3部分copy代码使用即可。

2、查准率(Precision)、查全率(Recall)与F1 2.1、二分类问题

对于一个二分类(正例与反例)问题,其分类结果的混淆矩阵(Confusion Matrix)如下:

预测的正例预测的反例真实的正例TP(真正例)FN(假反例)真实的反例FP(假正例)TN(真 反例)

则查准率P定义为: P = T P T P + F P P=\frac{TP}{TP+FP} P=TP+FPTP​ 查全率R定义为: R = T P T P + F N R=\frac{TP}{TP+FN} R=TP+FNTP​ 可见,查准率与查全率是一对相互矛盾的量,前者表示的是预测的正例真的是正例的概率,而后者表示的是将真正的正例预测为正例的概率。听上去有点绕,通俗地讲,假设这个分类问题是从一批西瓜中辨别哪些是好瓜哪些是不好的瓜,查准率高则意味着你挑选的好瓜大概率真的是好瓜,但是由于你选瓜的标准比较高,这意味着你也错失了一些好瓜(宁缺毋滥);而查全率高则意味着你能选到大部分的好瓜,但是由于你为了挑选到尽可能多的好瓜,降低了选瓜的标准,这样许多不太好的瓜也被当成了好瓜选进来。通过调整你选瓜的“门槛”,就可以调整查准率与查全率,即:“门槛”高,则查准率高而查全率低;“门槛”低,则查准率低而查全率高。通常,查准率与查全率不可兼得,在不同的任务中,对查准率与查全率的的重视程度也会有所不同。这时,我们就需要一个综合考虑查准率与查全率的性能指标了,比如 F β F_{\beta} Fβ​,该值定义为: F β = ( 1 + β 2 ) × P × R ( β 2 × P ) + R F_{\beta}=\frac{(1+{\beta}^2) \times P \times R}{({\beta}^2 \times P)+R} Fβ​=(β2×P)+R(1+β2)×P×R​ 其中 β \beta β度量了查全率对查准率的相对重要性, β > 1 \beta >1 β>1时,查全率影响更大, β < 1 \beta > src = torch.arange(1, 11).reshape((2, 5)) >>> src tensor([[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]]) >>> index = torch.tensor([[0, 1, 2, 0]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src) tensor([[1, 0, 0, 4, 0], [0, 2, 0, 0, 0], [0, 0, 3, 0, 0]])

该代码样例中对一个维度为(3,5)的全0张量进行了计算,参数dim=0,结果张量中,值发生变化的位置分别是(0,0)、(1,1)、(2,2)、(0,3)。细心的朋友应该能发现,这四个位置的第一个维度(dim=0)刚好是张量index的各个元素,而第二个维度则是index中对应元素的另一个维度的值。再看另一个例子

>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]]) >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src) tensor([[1, 2, 3, 0, 0], [6, 7, 0, 0, 8], [0, 0, 0, 0, 0]])

该代码中值发生变化的位置是(0,0)、(0,1)、(0,2)、(1,0)、(1,1)、(1,4),同样的,这几个位置的第二个维度(dim=1)是index中的各个元素,而第一个维度则是index中对应元素的另一个维度的值。在涉及到高维张量的时候,这个操作的空间意义我们可能难以想想,但是记住这点而不去追求理解其空间意义,再高的维度,也能很好地理解这个计算的具体操作。

4、PyTorch实战与代码解析

接下来就是在PyTorch实现对查准率、查全率与F1的计算了,在该实例中,我们用到了scatter_()函数,首先来看完整的数据集测试过程代码,便于各位理解与取用,然后再对具体的计算代码进行阐述。

def test(valid_queue, net, criterion): net.eval() test_loss = 0 target_num = torch.zeros((1, n_classes)) # n_classes为分类任务类别数量 predict_num = torch.zeros((1, n_classes)) acc_num = torch.zeros((1, n_classes)) with torch.no_grad(): for step, (inputs, targets) in enumerate(valid_queue): inputs, targets = inputs.to(device), targets.to(device) outputs, _ = net(inputs) loss = criterion(outputs, targets) test_loss += loss.item() _, predicted = outputs.max(1) pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.) predict_num += pre_mask.sum(0) # 得到数据中每类的预测量 tar_mask = torch.zeros(outputs.size()).scatter_(1, targets.data.cpu().view(-1, 1), 1.) target_num += tar_mask.sum(0) # 得到数据中每类的数量 acc_mask = pre_mask * tar_mask acc_num += acc_mask.sum(0) # 得到各类别分类正确的样本数量 recall = acc_num / target_num precision = acc_num / predict_num F1 = 2 * recall * precision / (recall + precision) accuracy = 100. * acc_num.sum(1) / target_num.sum(1) print('Test Acc {}, recal {}, precision {}, F1-score {}'.format(accuracy, recall, precision, F1)) return accuracy

首先看下面这行代码,产生一个大小为(batch_size , n_classes)的全0张量,然后将predicted的维度变成(batch_size,1),每个元素都代表的是其分类对应的编号,通过scatter_()函数,将1写入了全0张量中的对应位置。得到的张量pre_mask就是每次预测结果的one-hot编码。

pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)

然后将pre_mask的第一个维度上的所有值进行加和,就得到了一个维度为(1 ,n_classes)的张量,该张量中的每个元素都是预测结果中对应的类的数量,对其进行累加,便得到了整个测试数据集的预测结果predict_num。

predict_num += pre_mask.sum(0)

tar_mask与target_num同理,此处不再赘述。

acc_mask = pre_mask * tar_mask

然后通过pre_mask * tar_mask,得到了一个表示分类正确的样本与对应类别的矩阵acc_mask,其维度为(batch_size , n_classes),对于其中一个元素acc_mask[i][j]=1,表示这个batch_size 中的第i个样本分类正确,类别为j。

acc_num += acc_mask.sum(0)

将acc_mask的第一个维度上的所有值进行加和,便可得到该batch_size数据中每个类别预测正确的数量,累加即可得到整个验证数据集中各类正确预测的样本数。

recall = acc_num / target_num precision = acc_num / predict_num F1 = 2 * recall * precision / (recall + precision) accuracy = 100. * acc_num.sum(1) / target_num.sum(1)

然后就可以计算各个指标了,很好理解,就不再解释了。

希望对你有所帮助,也欢迎在评论区提出你的想法与意见。

5、参考

[1].《机器学习》,周志华

[2].https://blog.csdn.net/qq_16234613/article/details/80039080

[3].https://zhuanlan.zhihu.com/p/46204175

[4].https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html?highlight=scatter_

rticle/details/80039080

[3].https://zhuanlan.zhihu.com/p/46204175

[4].https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html?highlight=scatter_



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3